Learning the Lennard-Jones (LJ) Potential¶
This notebook demonstrates how to learn a pair potential (e.g., Lennard-Jones) from Molecular Dynamics (MD) simulations by optimizing against Radial Distribution Functions (RDFs) and Velocity Auto-Correlation Functions (VACFs).
Key Concepts¶
- Lennard-Jones Potential: Describes interactions between two particles:
$$ U(r) = 4 \varepsilon \left[ \left( \frac{\sigma}{r} \right)^{12} - \left( \frac{\sigma}{r} \right)^6 \right] $$
- MD Simulations: Use Nose-Hoover Chain thermostat to evolve the system.
- Objective: Learn a potential function that reproduces the RDF and VACF of a reference system.
Importing Required Libraries¶
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms
from ase.lattice.cubic import FaceCenteredCubic, Diamond
from ase.visualize import *
from data_src.data import *
from potential_src.pairMLP.potential_PairMLP import *
from observables.rdf import *
from observables.observers import *
from utils.get_utils import *
def plot_vacf(vacf_sim, vacf_target, fn, path, dt=0.01, save_data=False):
    """Plot a simulated VACF against an optional target curve.

    Parameters:
        vacf_sim: 1D array of simulated VACF values (one entry per lag).
        vacf_target: 1D array of reference VACF values, or None to plot
            the simulation alone.
        fn: label embedded in the output file names.
        path: directory where data/figure files are written.
        dt: timestep used to build the time axis.
        save_data: when True, write the curves as text and the figure as PDF.
    """
    t_range = np.linspace(0.0, vacf_sim.shape[0], vacf_sim.shape[0]) * dt
    plt.plot(t_range, vacf_sim, label='simulation', linewidth=4, alpha=0.6)
    if vacf_target is not None:
        plt.plot(t_range, vacf_target, label='target', linewidth=2,
                 linestyle='--', c='black')
    plt.legend()
    if save_data:
        np.savetxt(path + '/vacf_{}.txt'.format(fn),
                   np.stack((t_range, vacf_sim)), delimiter=',')
        # BUGFIX: the original unconditionally stacked vacf_target, which
        # raises when no target curve is available (vacf_target is None).
        if vacf_target is not None:
            np.savetxt(path + '/vacf_{}_target.txt'.format(fn),
                       np.stack((t_range, vacf_target)), delimiter=',')
        # Save before plt.show(): show() hands the figure to the GUI event
        # loop, so a later savefig() can write an empty figure.
        plt.savefig(path + '/vacf_{}.pdf'.format(fn), bbox_inches='tight')
    plt.show()
    plt.close()
def plot_rdf(g_sim, rdf_target, fn, path, start, nbins, save_data=False, end=2.5):
    """Plot a simulated RDF against the target RDF on a shared distance grid.

    Parameters:
        g_sim: simulated g(r) values (nbins entries).
        rdf_target: reference g(r) values (nbins entries).
        fn: label embedded in the output file names.
        path: directory where data/figure files are written.
        start, end: distance window used to rebuild the bin grid.
        nbins: number of RDF bins.
        save_data: when True, also write the curves as text files.
    """
    bins = np.linspace(start, end, nbins)
    plt.plot(bins, g_sim, label='simulation', linewidth=4, alpha=0.6)
    plt.plot(bins, rdf_target, label='target', linewidth=2,
             linestyle='--', c='black')
    plt.xlabel(r"$\AA$")  # raw string: "\A" is an invalid escape in a plain literal
    plt.ylabel("g(r)")
    if save_data:
        np.savetxt(path + '/rdf_{}.txt'.format(fn),
                   np.stack((bins, g_sim)), delimiter=',')
        np.savetxt(path + '/rdf_{}_target.txt'.format(fn),
                   np.stack((bins, rdf_target)), delimiter=',')
    # BUGFIX: save before plt.show(); the original called show() first,
    # which can leave the saved PDF blank.
    plt.savefig(path + '/rdf_{}.pdf'.format(fn), bbox_inches='tight')
    plt.show()
    plt.close()
def plot_pair(fn, path, model, prior, device, end=2.5, target_pot=None):
    """Plot the learned pair potential (model + prior) against ground truth.

    Parameters:
        fn: label embedded in the figure file name.
        path: directory where the figure is written.
        model: callable mapping distances (N, 1 tensor) to learned energies.
        prior: callable mapping distances to prior energies.
        device: torch device the distance grid is evaluated on.
        end: upper bound of the distance grid.
        target_pot: ground-truth potential module; defaults to LennardJones(1, 1).

    Returns:
        The fitted potential on the plotting grid (numpy array), shifted so
        it vanishes at the last grid point.
    """
    if target_pot is None:
        target_pot = LennardJones(1.0, 1.0)
    else:
        target_pot = target_pot.to("cpu")
    x = torch.linspace(0.1, end, 250)[:, None].to(device)
    u_fit = (model(x) + prior(x)).detach().cpu().numpy()
    # Shift so the fitted curve is zero at the cutoff end of the grid.
    u_fit = u_fit - u_fit[-1]
    u_target = target_pot(x.detach().cpu()).squeeze()
    plt.plot(x.detach().cpu().numpy(), u_fit,
             label='fit', linewidth=4, alpha=0.6)
    plt.plot(x.detach().cpu().numpy(), u_target.detach().cpu().numpy(),
             label='truth', linewidth=2, linestyle='--', c='black')
    plt.ylim(-2, 4.0)
    plt.legend()
    # BUGFIX: save before plt.show() so the written image is not blank.
    plt.savefig(path + '/potential_{}.jpg'.format(fn), bbox_inches='tight')
    plt.show()
    plt.close()
    return u_fit
System Definition¶
We first define the simulation parameters, such as density, temperature, and the cutoff radius. We also initialize the MD system using ASE (Atomic Simulation Environment).
# Reference-system registry: one entry per LJ state point (density, temperature).
# Keys are consumed by get_system / get_observer / get_target_obs below.
pair_data_dict = { 'lj_0.3_1.2': {
    'rdf_fn': '../data/LJ_data/rdf_rho0.3_T1.2_dt0.01.csv' ,    # reference RDF file
    'vacf_fn': '../data/LJ_data/vacf_rho0.3_T1.2_dt0.01.csv' ,  # reference VACF file
    'rho': 0.3,                      # number density
    'T': 1.2,                        # temperature (reduced units)
    'start': 0.75,                   # RDF window lower bound
    'end': 3.3,                      # RDF window upper bound
    'element': "H",                  # placeholder element for the ASE lattice
    'mass': 1.0,
    "N_unitcell": 4,                 # atoms per FCC unit cell
    "cell": FaceCenteredCubic,       # ASE lattice factory
    "target_pot": LennardJones()     # ground-truth pair potential (project module)
    },
}
# MLP hidden-layer width presets, keyed by capacity label.
width_dict = dict(tiny=64, low=128, mid=256, high=512)

# Number of Gaussian basis functions for distance smearing, per capacity label.
gaussian_dict = dict(tiny=16, low=32, mid=64, high=128)
The simulation in this code is self-consistent, meaning all computations are handled within the script without relying on external molecular dynamics software like LAMMPS or GROMACS. Everything—potential evaluation, numerical integration, observables extraction, and model training—is done internally using PyTorch, ASE, and custom simulation modules.
The function 'get_system()' initializes a periodic atomic system in 3D using ASE. We need to define:
- density
- temperature
- lattice type (e.g. FCC)
- PBCs
def get_system(data_str, device, size):
    """Build an ASE-backed NVT system for the state point named by data_str.

    Parameters:
        data_str: key into pair_data_dict.
        device: device string/handle passed to the System wrapper.
        size: number of unit cells replicated along each box edge.

    Returns:
        A System on a periodic lattice at the registered density, with
        velocities initialized to the registered temperature.

    Raises:
        ValueError: for dimensionalities other than 3.  The original code
        silently fell through the `if dim == 3` branch and crashed later
        with a NameError on `system`.
    """
    rho = pair_data_dict[data_str]['rho']
    T = pair_data_dict[data_str]['T']
    dim = pair_data_dict[data_str].get("dim", 3)
    if dim != 3:
        raise ValueError("get_system only supports 3D systems, got dim={}".format(dim))

    # initialize states with ASE
    cell_module = pair_data_dict[data_str]['cell']
    N_unitcell = pair_data_dict[data_str]['N_unitcell']

    # Lattice parameter from the number density: L^3 = N_unitcell / rho.
    L = (N_unitcell / rho) ** (1 / 3)
    print("lattice param:", L)

    atoms = cell_module(symbol=pair_data_dict[data_str]['element'],
                        size=(size, size, size),
                        latticeconstant=L,
                        pbc=True)
    system = System(atoms, device=device)
    system.set_temperature(T)
    return system
The system uses a ground-truth potential (target_pot) from pair_data_dict. This potential can be: (1) Lennard-Jones (LJ), for simple atomic interactions, or (2) a custom pair potential, for more complex interactions. The potential is implemented in PyTorch, so forces and energies are computed dynamically without external MD engines.
The function get_target_obs() sets up and runs the simulation.
def get_target_obs(system, data_str, n_sim, rdf_range, nbins, t_range, dt, skip=25):
    """Run reference MD with the ground-truth potential and return the
    target observables (segment-averaged RDF and VACF).

    Parameters:
        system: System to evolve (modified in place by the integrator).
        data_str: key into pair_data_dict (supplies potential and temperature).
        n_sim: number of 100-step simulation segments to run.
        rdf_range: (start, end) distance window for the RDF.
        nbins: number of RDF bins.
        t_range: number of VACF lags.
        dt: integration timestep.
        skip: number of initial segments discarded as equilibration.

    Returns:
        (rdf, vacf): numpy arrays averaged over the kept segments.
    """
    print("simulating {}".format(data_str))
    device = system.device
    # simulation setup
    target_pot = pair_data_dict[data_str]['target_pot'] # take target potential
    T = pair_data_dict[data_str]['T'] # take target temperature
    # NOTE(review): the cutoff is hard-coded to 2.5 and independent of
    # rdf_range -- confirm it matches the reference-data generation.
    pot = PairPotentials(system, target_pot, cutoff=2.5, nbr_list_device=device).to(device)
    diffeq = NoseHooverChain(pot,
                             system,
                             Q=50.0,
                             T=T,
                             num_chains=5,
                             adjoint=True,
                             topology_update_freq=1).to(system.device)
    # define simulator with
    sim = Simulations(system, diffeq)
    # define objects for the observables
    rdf_obs = rdf(system, nbins=nbins, r_range=rdf_range)
    vacf_obs = vacf(system, t_range=t_range)
    all_vacf_sim = []
    # Run MD Simulations & Extract Target Data
    for i in range(n_sim):
        # 100 steps per segment; `frequency` controls how often the state is
        # logged during a segment.
        v_t, q_t, pv_t = sim.simulate(100, dt=dt, frequency=100)
        if i >= skip:
            vacf_sim = vacf_obs(v_t).detach().cpu().numpy()
            all_vacf_sim.append(vacf_sim)
    # loop over logged frames to compute the RDF, skipping equilibration
    trajs = torch.Tensor( np.stack( sim.log['positions'])).to(system.device).detach()
    all_g_sim = []
    for i in range(len(trajs)):
        if i >= skip:
            _, _, g_sim = rdf_obs(trajs[[i]])
            all_g_sim.append(g_sim.detach().cpu().numpy())
    all_g_sim = np.array(all_g_sim).mean(0)
    all_vacf_sim = np.array(all_vacf_sim).mean(0)
    return all_g_sim, all_vacf_sim
def get_observer(system, data_str, nbins, t_range, rdf_start):
    """Set up RDF/VACF observables and obtain the target curves for data_str.

    Returns:
        xnew: distance grid for the RDF (numpy array, nbins points).
        rdf_target: target RDF on that grid (tensor on obs.device).
        obs: RDF observable callable.
        vacf_target: target VACF tensor, or None when the RDF came from a file.
        vacf_obs: VACF observable callable.
    """
    # get dt
    dt = pair_data_dict[data_str].get('dt', 0.01)
    rdf_end = pair_data_dict[data_str].get("end", None)
    xnew = np.linspace(rdf_start , rdf_end, nbins)
    # initialize observable function
    obs = rdf(system, nbins, (rdf_start , rdf_end) )
    vacf_obs = vacf(system, t_range=t_range)
    # get experimental rdf
    dim = pair_data_dict[data_str].get("dim", 3)
    # NOTE(review): this looks up key "fn", but pair_data_dict entries define
    # 'rdf_fn'/'vacf_fn', so rdf_data_path is always None and the targets are
    # always regenerated by simulation -- confirm whether "fn" is intended.
    rdf_data_path = pair_data_dict[data_str].get("fn", None)
    # generate simulated data
    if not rdf_data_path:
        rdf_data, vacf_target = get_target_obs(system, data_str, 200, (rdf_start, rdf_end), nbins=nbins, t_range=t_range, skip=50, dt=dt)
        vacf_target = torch.Tensor(vacf_target).to(system.device)
        rdf_data = np.vstack( (np.linspace(rdf_start, rdf_end, nbins), rdf_data))
    else:
        # experimental rdfs loaded from disk (no VACF target available)
        rdf_data = np.loadtxt(rdf_data_path, delimiter=',')
        vacf_target = None
    _, rdf_target = get_exp_rdf(rdf_data, nbins, (rdf_start, rdf_end), obs.device, dim=dim)
    # get model potential and simulate
    return xnew, rdf_target, obs, vacf_target, vacf_obs
def get_sim(system, model, data_str, topology_update_freq=1):
    """Wrap `model` in a Nose-Hoover chain integrator at the registered
    temperature and return a Simulations driver for `system`."""
    target_T = pair_data_dict[data_str]['T']
    integrator = NoseHooverChain(
        model,
        system,
        Q=50.0,
        T=target_T,
        num_chains=5,
        adjoint=True,
        topology_update_freq=topology_update_freq,
    ).to(system.device)
    return Simulations(system, integrator)
Define Training Hyperparameters¶
– Set learning rate, epochs, and training options.
# --- training hyperparameters -------------------------------------------
params = {
    'val': []
}
sys_params = {
    'val': params['val'],
}

n_epochs = 1000           # maximum optimization epochs
n_sim = 200               # segments for the final evaluation runs
size = 4                  # unit cells per box edge
cutoff = 2.5              # pair-potential cutoff radius
t_range = 50              # number of VACF lags
device = 'cpu'
nbins = 100               # RDF bins
tau = 60                  # steps per training segment  # assignments['opt_freq']
rdf_start = 0.75          # RDF window lower bound
skip = 1
nbr_list_device = 'cpu'
topology_update_freq = 1

data_str_list = ['lj_0.3_1.2']

# Ground-truth pair potential used for reference data and plots.
target_pot = LennardJones()

# Validation state points, if any were configured above.
val_str_list = sys_params['val'] if sys_params['val'] else []

# Output directory for figures, data and checkpoints.
project_name = 'lj_sim'
suggestion_id = 'run0'
model_path = f'{project_name}/{suggestion_id}'
if not os.path.exists(model_path):
    os.makedirs(model_path)

print(f"Training for {n_epochs} epochs")
train_vacf = True
Training for 1000 epochs
Initialize the MD System¶
– Set up atoms, boundary conditions, and simulation parameters
# Instantiate one System per state point (training + validation).
# NOTE: the loop intentionally leaves `system` bound to the last instance;
# later cells reference that module-level `system`.
system_list = []
for data_str in data_str_list+val_str_list:
    system = get_system(data_str, device, size)
    system_list.append(system)
lattice param: 2.371262202993375
/Users/francescapagano/miniconda3/envs/myenv_mlmd/lib/python3.9/site-packages/ase/md/md.py:52: FutureWarning: Specify the temperature in K using the 'temperature_K' argument warnings.warn(FutureWarning(w))
what is system ?¶
System is the starting point for the simulation:
from ase.visualize import view
# Opens an external ASE viewer window for the first system (side effect only).
view(system_list[0])
<Popen: returncode: None args: ['/Users/francescapagano/miniconda3/envs/myen...>
2025-01-29 14:15:52.142 python[80563:3483549] +[IMKClient subclass]: chose IMKClient_Legacy 2025-01-29 14:15:52.142 python[80563:3483549] +[IMKInputSession subclass]: chose IMKInputSession_Legacy
Potential¶
- LJ potential is the physics-informed prior potential
- The neural network model (pairMLP) learns corrections to the prior potential
class LJFamily(torch.nn.Module):
    """Generalized Lennard-Jones pair potential with tunable exponents.

    U(r) = 4 * epsilon * [(sigma/r)**rep_pow - (sigma/r)**attr_pow]

    sigma and epsilon are trainable parameters; the exponents are fixed.
    (Reconstructed from a garbled notebook export: the original cell lost
    its newlines and the `__init__` dunder underscores.)
    """

    def __init__(self, sigma=1.0, epsilon=1.0, attr_pow=6, rep_pow=12):
        super(LJFamily, self).__init__()
        self.sigma = torch.nn.Parameter(torch.Tensor([sigma]))
        self.epsilon = torch.nn.Parameter(torch.Tensor([epsilon]))
        self.attr_pow = attr_pow
        self.rep_pow = rep_pow

    def LJ(self, r, sigma, epsilon):
        # Standard 12-6 LJ when rep_pow=12 and attr_pow=6.
        return 4 * epsilon * ((sigma/r)**self.rep_pow - (sigma/r)**self.attr_pow)

    def forward(self, x):
        return self.LJ(x, self.sigma, self.epsilon)
mlp_parmas = {'n_gauss': int(cutoff//0.10), # cutoff//gaussian_width 'r_start': 0.0, 'r_end': cutoff, 'n_width': 128, 'n_layers': 3, 'nonlinear': 'ELU'} lj_params = {'epsilon': 0.4, 'sigma': 0.9, "power": 10}
NN = pairMLP(mlp_parmas) pair = LJFamily(epsilon=2.0, sigma=0.9, rep_pow=6, attr_pow=3) # ExcludedVolume(lj_params)
# Build one Stack(NN + prior) model per state point
# (reformatted from a collapsed notebook-export line).
model_list = []
for i, data_str in enumerate(data_str_list + val_str_list):
    pairNN = PairPotentials(system_list[i], NN,
                            cutoff=cutoff,
                            nbr_list_device=nbr_list_device
                            ).to(device)
    prior = PairPotentials(system_list[i], pair,
                           cutoff=2.5,
                           nbr_list_device=nbr_list_device
                           ).to(device)
    model = Stack({'pairnn': pairNN, 'pair': prior})
    model_list.append(model)
Stack combines:
- A learned NN potential (pairnn)
- A predefined Lennard-Jones potential (pair) The model will sum the outputs of pairnn and pair to compute the total energy.
Data Flow¶
Atomic Positions → Compute Pairwise Distances (r_ij) →
│
├──> Lennard-Jones Potential (pair) → U_LJ(r_ij)
│
├──> Neural Network Potential (pairnn)
│ ├── GaussianSmearing(r_ij)
│ ├── Feedforward NN (Linear + ELU Layers)
│ └── Output: U_MLP(r_ij)
│
└──> Final Potential: U_total = U_LJ + U_MLP → Sum Over Pairs → Compute System Energy
│
└──> Use for Molecular Dynamics (Forces, Simulations, Training)
def print_model_summary(model):
    """Print the architecture of every sub-model held by a Stack."""
    print("Model Structure:")
    print("---------------")
    for label, sub in model.models.items():
        print(f"\n{label}:")
        print(" Network architecture:")
        for part_name, part in sub.named_children():
            print(f" {part_name}: {part}")
        # Pair-potential wrappers expose the underlying network as `.model`.
        if hasattr(sub, 'model'):
            print("\n Potential network:")
            for inner_name, inner in sub.model.named_children():
                print(f" {inner_name}: {inner}")
# Use it on your model
print_model_summary(model)  # summarizes the last Stack built in the loop above
Model Structure:
---------------
pairnn:
Network architecture:
model: pairMLP(
(smear): GaussianSmearing()
(layers): ModuleList(
(0): Linear(in_features=24, out_features=24, bias=True)
(1): ELU(alpha=1.0)
(2): Linear(in_features=24, out_features=128, bias=True)
(3): ELU(alpha=1.0)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): ELU(alpha=1.0)
(6): Linear(in_features=128, out_features=128, bias=True)
(7): ELU(alpha=1.0)
(8): Linear(in_features=128, out_features=128, bias=True)
(9): ELU(alpha=1.0)
(10): Linear(in_features=128, out_features=24, bias=True)
(11): ELU(alpha=1.0)
(12): Linear(in_features=24, out_features=1, bias=True)
)
)
Potential network:
smear: GaussianSmearing()
layers: ModuleList(
(0): Linear(in_features=24, out_features=24, bias=True)
(1): ELU(alpha=1.0)
(2): Linear(in_features=24, out_features=128, bias=True)
(3): ELU(alpha=1.0)
(4): Linear(in_features=128, out_features=128, bias=True)
(5): ELU(alpha=1.0)
(6): Linear(in_features=128, out_features=128, bias=True)
(7): ELU(alpha=1.0)
(8): Linear(in_features=128, out_features=128, bias=True)
(9): ELU(alpha=1.0)
(10): Linear(in_features=128, out_features=24, bias=True)
(11): ELU(alpha=1.0)
(12): Linear(in_features=24, out_features=1, bias=True)
)
pair:
Network architecture:
model: LJFamily()
Potential network:
def count_parameters(model):
    """Print a per-parameter breakdown of a torch module and return totals.

    Parameters:
        model: any torch.nn.Module (or object exposing named_parameters()).

    Returns:
        (total_params, trainable_params): the original computed these counts
        but discarded them; returning them keeps the printed output identical
        while making the values reusable by callers.
    """
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        if param.requires_grad:
            trainable_params += param_count
        print(f"{name}: {list(param.shape)} ({param_count} parameters)")
    print(f"\nTotal parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    return total_params, trainable_params
count_parameters(model)
models.pairnn.model.smear.width: [24] (24 parameters) models.pairnn.model.smear.offsets: [24] (24 parameters) models.pairnn.model.layers.0.weight: [24, 24] (576 parameters) models.pairnn.model.layers.0.bias: [24] (24 parameters) models.pairnn.model.layers.2.weight: [128, 24] (3072 parameters) models.pairnn.model.layers.2.bias: [128] (128 parameters) models.pairnn.model.layers.4.weight: [128, 128] (16384 parameters) models.pairnn.model.layers.4.bias: [128] (128 parameters) models.pairnn.model.layers.6.weight: [128, 128] (16384 parameters) models.pairnn.model.layers.6.bias: [128] (128 parameters) models.pairnn.model.layers.8.weight: [128, 128] (16384 parameters) models.pairnn.model.layers.8.bias: [128] (128 parameters) models.pairnn.model.layers.10.weight: [24, 128] (3072 parameters) models.pairnn.model.layers.10.bias: [24] (24 parameters) models.pairnn.model.layers.12.weight: [1, 24] (24 parameters) models.pairnn.model.layers.12.bias: [1] (1 parameters) models.pair.model.sigma: [1] (1 parameters) models.pair.model.epsilon: [1] (1 parameters) Total parameters: 56507 Trainable parameters: 56507
# Pair each system with its model inside a Simulations driver.
sim_list = []
for idx, name in enumerate(data_str_list + val_str_list):
    sim_list.append(get_sim(system_list[idx],
                            model_list[idx],
                            name,
                            topology_update_freq=topology_update_freq))
sim_list
[<md.sim.Simulations at 0x128295b50>]
This cell initializes and collects observables for each molecular system in data_str_list + val_str_list. The goal is to compute and store Radial Distribution Function (RDF) and Velocity Auto-Correlation Function (VACF), both for sim and target.
# Containers for training-time observables and their ground-truth targets.
rdf_obs_list = []      # RDF observable callables used during training
vacf_obs_list = []     # VACF observable callables used during training
rdf_target_list = []   # ground-truth RDF curves
vacf_target_list = []  # ground-truth VACF curves
rdf_bins_list = []     # distance bins backing each RDF

for i, data_str in enumerate(data_str_list + val_str_list):
    rdf_start = pair_data_dict[data_str].get("start", 0.75)
    # Keep the original unpack names: later cells read these module-level
    # bindings (e.g. vacf_target).
    x, rdf_target, rdf_obs, vacf_target, vacf_obs = get_observer(
        system_list[i],
        data_str,
        nbins,
        t_range=t_range,
        rdf_start=rdf_start,
    )
    for collection, item in ((rdf_bins_list, x),
                             (rdf_obs_list, rdf_obs),
                             (rdf_target_list, rdf_target),
                             (vacf_obs_list, vacf_obs),
                             (vacf_target_list, vacf_target)):
        collection.append(item)
simulating lj_0.3_1.2
import torch
import matplotlib.pyplot as plt

# Visualize the ground-truth RDF for the first state point.
rdf_target = rdf_target_list[0].cpu().numpy()  # tensor -> numpy for matplotlib

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(rdf_bins_list[0], rdf_target, label="RDF Target", color='b', linewidth=2)
ax.set_xlabel("Distance (Å)")
ax.set_ylabel("g(r)")
ax.set_title("Radial Distribution Function (RDF)")
ax.legend()
ax.grid()
plt.show()
import torch
import matplotlib.pyplot as plt

# Convert the tensor to NumPy first, then build the matching time axis.
# BUGFIX: the original built `time_bins` from a stale `vacf_target` binding
# *before* reassigning it from vacf_target_list, which breaks whenever the
# two arrays differ in length.
vacf_target = vacf_target_list[0].cpu().numpy()
time_bins = np.arange(len(vacf_target)) * 0.005  # assumes dt = 0.005 -- TODO confirm

plt.figure(figsize=(8, 5))
plt.plot(time_bins, vacf_target, label="VACF Target", color='b', linewidth=2)
plt.xlabel("Time (fs)")
plt.ylabel("VACF")
plt.title("Velocity Auto-Correlation Function (VACF)")
plt.legend()
plt.grid()
plt.show()
Initialize Training Components¶
# Optimize only the MLP correction; the prior parameters are not trained here.
optimizer = torch.optim.Adam(list(NN.parameters()), lr=0.002)

# Halve the LR when the loss plateaus.  The deprecated `verbose=True` flag
# (which only echoed LR changes and triggered the FutureWarning seen in the
# notebook output) has been removed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       'min',
                                                       min_lr=1e-6,
                                                       factor=0.5,
                                                       patience=20,
                                                       threshold=5e-5)

# Per-epoch [vacf_loss, rdf_loss] history.
loss_log = []

# Per-state-point history of simulated observables across epochs.
obs_log = dict()
/Users/francescapagano/miniconda3/envs/myenv_mlmd/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate. warnings.warn(
Precompute Observables for Training¶
# Allocate empty RDF/VACF history lists for every state point.
for data_str in data_str_list + val_str_list:
    obs_log[data_str] = {'rdf': [], 'vacf': []}
Training¶
# Relative weights of the two observable-matching loss terms.
rdf_weight = 0.95
vacf_weight = 0.0

for i in range(n_epochs):
    loss_rdf = torch.Tensor([0.0]).to(device)
    loss_vacf = torch.Tensor([0.0]).to(device)

    n_train = len(data_str_list)
    for j, sim in enumerate(sim_list[:n_train]):
        data_str = (data_str_list + val_str_list)[j]

        # get dt
        dt = 0.005

        # Simulate one training segment of `tau` steps.
        v_t, q_t, pv_t = sim.simulate(steps=tau, frequency=tau, dt=dt)

        # Validation systems contribute observables but no gradients.
        if data_str in val_str_list:
            v_t = v_t.detach()
            q_t = q_t.detach()
            pv_t = pv_t.detach()

        if torch.isnan(q_t.reshape(-1)).sum().item() > 0:
            print("encounter NaN")
            print( 5 - (i / n_epochs) * 5 )
            break

        #_, _, g_sim = rdf_obs_list[j](q_t[::skip])
        # save memory by computing it in serial
        skip = 5
        n_frames = q_t[::skip].shape[0]
        for idx in range(n_frames):
            if idx == 0:
                _, _, g_sim = rdf_obs_list[j](q_t[::skip][[idx]])
            else:
                g_sim += rdf_obs_list[j](q_t[::skip][[idx]])[2]
        g_sim = g_sim / n_frames

        # compute vacf
        vacf_sim = vacf_obs_list[j](v_t)

        # Accumulate losses for training state points only.
        if data_str in data_str_list:
            if vacf_target_list[j] is not None:
                loss_vacf += (vacf_sim - vacf_target_list[j][:t_range]).pow(2).mean()
            else:
                loss_vacf += 0.0
            drdf = g_sim - rdf_target_list[j]
            loss_rdf += (drdf).pow(2).mean()  #+ JS_rdf(g_sim, rdf_target_list[j])

        obs_log[data_str]['rdf'].append(g_sim.detach().cpu().numpy())
        obs_log[data_str]['vacf'].append(vacf_sim.detach().cpu().numpy())

        # Periodically plot current observables against the targets.
        if i % 5 == 0:
            if vacf_target_list[j] is not None:
                vacf_target = vacf_target_list[j][:t_range].detach().cpu().numpy()
            else:
                vacf_target = None
            rdf_target = rdf_target_list[j].detach().cpu().numpy()
            plot_vacf(vacf_sim.detach().cpu().numpy(), vacf_target,
                      fn=data_str + "_{}".format(str(i).zfill(3)),
                      dt=dt,
                      path=model_path)
            plot_rdf(g_sim.detach().cpu().numpy(), rdf_target,
                     fn=data_str + "_{}".format(str(i).zfill(3)),
                     path=model_path,
                     start=rdf_start,
                     nbins=nbins,
                     end=rdf_obs_list[j].r_axis[-1])

    # Periodically plot the learned pair potential against the ground truth.
    if i % 5 == 0:
        potential = plot_pair(path=model_path,
                              fn=str(i).zfill(3),
                              model=sim.integrator.model.models['pairnn'].model,
                              prior=sim.integrator.model.models['pair'].model,
                              device=device,
                              target_pot=target_pot.to(device),
                              end=cutoff)

    # BUGFIX: `train_vacf` is a bool, so the original `train_vacf == "True"`
    # string comparison was always False and the VACF term was silently
    # dropped.  (With vacf_weight = 0.0 the loss value is unchanged, but the
    # flag now behaves as intended.)
    if train_vacf:
        loss = rdf_weight * loss_rdf + vacf_weight * loss_vacf
    else:
        loss = rdf_weight * loss_rdf

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(loss_vacf.item(), loss_rdf.item())
    scheduler.step(loss)
    loss_log.append([loss_vacf.item(), loss_rdf.item()])

    # Stop once the scheduler has decayed the LR to (near) its floor.
    current_lr = optimizer.param_groups[0]["lr"]
    if current_lr <= 1e-5:
        print("training converged")
        break

np.savetxt(model_path + '/loss.txt', np.array(loss_log), delimiter=',')

# # save potentials
# if np.array(loss_log)[-10:, 1].mean() <= 0.005:
#     np.savetxt(model_path + '/potential.txt', potential, delimiter=',')
0.04633055999875069 0.010043300688266754 0.20236946642398834 0.05829494819045067 0.13431745767593384 0.11968377977609634 0.06419174373149872 0.16333819925785065 0.038015469908714294 0.14341723918914795
0.02930583991110325 0.09158311784267426 0.04251861944794655 0.090910904109478 0.03419381007552147 0.14977754652500153 0.01796233467757702 0.15637627243995667 0.01547524705529213 0.09807882457971573
0.04618494585156441 0.05017462372779846 0.11627700179815292 0.021105671301484108 0.043983835726976395 0.06279174983501434 0.030113138258457184 0.1281590461730957 0.029843416064977646 0.1573271006345749
0.019405346363782883 0.227510467171669 0.024593450129032135 0.2552144527435303 0.018196361139416695 0.29571297764778137 0.02855883166193962 0.2951563894748688 0.021850736811757088 0.2793882489204407
0.038046326488256454 0.27704596519470215 0.02740645594894886 0.18910230696201324 0.017252592369914055 0.1244056448340416 0.020001379773020744 0.07963059842586517 0.011261511594057083 0.062042560428380966
0.012782521545886993 0.032121773809194565 0.02184639871120453 0.029230860993266106 0.032829657196998596 0.021662866696715355 0.06968507915735245 0.030158495530486107 0.043725598603487015 0.04170085862278938
0.02225402556359768 0.03344520181417465 0.0833972617983818 0.02449600212275982 0.03217475116252899 0.018466275185346603 0.03851333260536194 0.021241605281829834 0.12616121768951416 0.015991264954209328
0.0511041060090065 0.009908926673233509 0.03359217569231987 0.014169629663228989 0.0743744894862175 0.004685154650360346 0.1356428861618042 0.012889456003904343 0.09922919422388077 0.008808398619294167
0.05750194191932678 0.010880760848522186 0.05852023512125015 0.010412020608782768 0.08340760320425034 0.005074560642242432 0.05566220358014107 0.008761439472436905 0.03940587118268013 0.010030992329120636
0.06671109795570374 0.010998988524079323 0.09722784161567688 0.013169880025088787 0.04666656628251076 0.014039542526006699 0.0266016386449337 0.00906712468713522 0.060235146433115005 0.0075272200629115105
0.09385934472084045 0.014995615929365158 0.051053486764431 0.007780490908771753 0.05436474829912186 0.004648578353226185 0.06913043558597565 0.006592984311282635 0.051048196852207184 0.005015496630221605
0.05617053434252739 0.003883780911564827 0.06154339015483856 0.007360624615103006 0.012405047193169594 0.007416481152176857 0.0717056542634964 0.004207460209727287 0.0799887627363205 0.004569487180560827
0.06424430012702942 0.0055542862974107265 0.03803030773997307 0.008251429535448551 0.06833816319704056 0.006393061950802803 0.0764475092291832 0.003876088885590434 0.044118452817201614 0.005164355039596558
0.06917167454957962 0.003544821171090007 0.10460229963064194 0.004990475717931986 0.061822351068258286 0.006077866535633802 0.036580268293619156 0.006292472593486309 0.07136891782283783 0.004308749921619892
0.1084817498922348 0.0019340902799740434 0.05874348059296608 0.0037314030341804028 0.05043186992406845 0.0024738553911447525 0.09408694505691528 0.0021624243818223476 0.09122275561094284 0.0038339602760970592
0.030772924423217773 0.008911254815757275 0.034080762416124344 0.005568155087530613 0.08209067583084106 0.0024854487273842096 0.0682661160826683 0.003161157015711069 0.044390853494405746 0.0038394152652472258
0.07547518610954285 0.001764682587236166 0.07235253602266312 0.004831680096685886 0.09986728429794312 0.004434170667082071 0.04135961830615997 0.0032481388188898563 0.09872177243232727 0.0033892521169036627
0.07601369172334671 0.0022193407639861107 0.08489355444908142 0.0016392340185120702 0.026955999433994293 0.0028773536905646324 0.028381120413541794 0.003647969337180257 0.0642792284488678 0.003932399675250053
0.08144722878932953 0.0026279992889612913 0.03627745434641838 0.003960017580538988 0.09532791376113892 0.002152936765924096 0.08060938119888306 0.002907180693000555 0.029194677248597145 0.0026513785123825073
0.042677298188209534 0.004090467467904091 0.09880457818508148 0.004020046442747116 0.08512777090072632 0.0034153214655816555 0.06658999621868134 0.0025431702379137278 0.03461385518312454 0.004609052557498217
0.05919656157493591 0.0031950033735483885 0.07202740013599396 0.003467089729383588 0.03581404313445091 0.00223873695358634 0.04688747972249985 0.003939190413802862 0.07860332727432251 0.002299464540556073
0.06195599213242531 0.0036474221851676702 0.062103915959596634 0.002677328884601593 0.019062094390392303 0.003090736223384738 0.052140358835458755 0.003462327178567648 0.07102754712104797 0.003356299363076687
0.08360378444194794 0.0033088738564401865 0.03269742801785469 0.0019728736951947212 0.10011973232030869 0.001103093265555799 0.05196038633584976 0.0020374611485749483 0.03210205212235451 0.003018932417035103
0.06971368938684464 0.001889049424789846 0.08849235624074936 0.002132921712473035 0.047778040170669556 0.001716303057037294 0.019058817997574806 0.0019802204333245754 0.07318150252103806 0.0015941888559609652
0.06338120996952057 0.0025376901030540466 0.06087113544344902 0.003344047348946333 0.05560697987675667 0.0024015139788389206 0.03515675663948059 0.002795663196593523 0.02359706163406372 0.004246007651090622
0.06295900046825409 0.0037083528004586697 0.024573376402258873 0.0038052713498473167 0.023760592564940453 0.005958267953246832 0.051972780376672745 0.0034134129527956247 0.0634133592247963 0.004205169156193733
0.03414866328239441 0.004493663087487221 0.04584605246782303 0.003998476546257734 0.01875460147857666 0.0039401911199092865 0.0831926092505455 0.0030581483151763678 0.09581798315048218 0.0015265141846612096
0.06330931186676025 0.0038221694994717836 0.07825545966625214 0.002808419754728675 0.1282954216003418 0.003078846260905266 0.042410045862197876 0.0025002574548125267 0.018115798011422157 0.003285877639427781
0.03828461095690727 0.0022987229749560356 0.10932434350252151 0.0019829105585813522 0.05933761224150658 0.001307878759689629 0.05370749533176422 0.003916655667126179 0.04825720936059952 0.0021974476985633373
0.08934042602777481 0.0010675211669877172 0.037287574261426926 0.003348633646965027 0.01513825636357069 0.004311643075197935 0.053105417639017105 0.0013432854320853949 0.12009597569704056 0.002277851104736328
0.050818633288145065 0.004724223166704178 0.03100726567208767 0.001874067005701363 0.09133799374103546 0.0011570833157747984 0.057620562613010406 0.001494632801041007 0.06150764599442482 0.0023313150741159916
0.08368604630231857 0.002848237520083785 0.05097782984375954 0.0016704823356121778 0.03992683067917824 0.0010176339419558644 0.025034936144948006 0.002109982306137681 0.02689593844115734 0.002892452524974942
0.07405448704957962 0.0034658554941415787 0.08389942348003387 0.0028659505769610405 0.02560444548726082 0.002557440660893917 0.03423386439681053 0.003203923348337412 0.054591938853263855 0.002834795042872429
0.022573962807655334 0.0034451084211468697 0.05296829715371132 0.0032701469026505947 0.05343877896666527 0.0017493261257186532 0.04474444314837456 0.002562630223110318 0.0907260999083519 0.0015969823580235243
0.020530246198177338 0.003802410326898098 0.024552499875426292 0.0041845072992146015 0.05318091809749603 0.003999505192041397 0.10100594162940979 0.0021584522910416126 0.06413844227790833 0.0015778415836393833
0.04441334307193756 0.002070643939077854 0.04549916088581085 0.005427600350230932 0.06794867664575577 0.003029446816071868 0.019710682332515717 0.002855653176084161 0.042121294885873795 0.005036984104663134
0.08018523454666138 0.005324619356542826 0.09015438705682755 0.00457659550011158 0.056581366807222366 0.004514981992542744 0.03990184888243675 0.005397412925958633 0.03853808715939522 0.0035117920488119125
0.028920775279402733 0.002667040564119816 0.08216548711061478 0.006145902443677187 0.07655035704374313 0.0020489783491939306 0.036234281957149506 0.002626583678647876 0.042886897921562195 0.00234964513219893
0.06521777808666229 0.0015487339114770293 0.07991483062505722 0.0015492331003770232 0.05614198371767998 0.005295565351843834 0.057382117956876755 0.0019163308897987008 0.07571377605199814 0.002422511810436845
0.06951935589313507 0.0031423550099134445 0.09660250693559647 0.0018593923887237906 0.07215885818004608 0.0025094132870435715 0.03378266841173172 0.0032088959123939276 0.06171073019504547 0.0029725623317062855
0.07213623821735382 0.0014286652440205216 0.038524121046066284 0.0022366081830114126 0.04296727105975151 0.0018621280323714018 0.06337431073188782 0.0011346233077347279 0.07453681528568268 0.0016184784471988678
0.027122242376208305 0.003590609412640333 0.06991582363843918 0.004532129503786564 0.04346119984984398 0.003347467165440321 0.024160807952284813 0.0023364608641713858 0.056272830814123154 0.00215146504342556
0.05117037892341614 0.0027367942966520786 0.07315518707036972 0.0031932187266647816 0.04946345463395119 0.0032869549468159676 0.02238074690103531 0.003056967630982399 0.06628426164388657 0.0038786500226706266
0.031161140650510788 0.0028349151834845543 0.013062585145235062 0.0011921291006729007 0.06554070115089417 0.002175277331843972 0.09362849593162537 0.0021663475781679153 0.030331695452332497 0.002279676264151931
0.048832572996616364 0.0023089454043656588 0.03622838109731674 0.0037944025825709105 0.04960097000002861 0.0022481100168079138 0.05980388820171356 0.002241757931187749 0.06918107718229294 0.002916302066296339
0.03470796346664429 0.0018953901017084718 0.04247225448489189 0.0028281821869313717 0.10363960266113281 0.002418776974081993 0.055149246007204056 0.003054825123399496 0.06060966104269028 0.0014402479864656925
0.04435794800519943 0.0018534105038270354 0.0733352079987526 0.0041503701359033585 0.0321279801428318 0.001898439833894372 0.03973294794559479 0.0020321100018918514 0.08100633323192596 0.002132127992808819
0.04649118334054947 0.0037992498837411404 0.020513296127319336 0.002808330114930868 0.05340828374028206 0.0035707640927284956 0.07799831032752991 0.0023940899409353733 0.03868403285741806 0.001460537314414978
0.06184737756848335 0.001812399597838521 0.147007554769516 0.0011774987215176225 0.04845963045954704 0.0019442903576418757 0.026379775255918503 0.002369264140725136 0.06107716262340546 0.001900967094115913
0.07491497695446014 0.0036729336716234684 0.0637986958026886 0.002132187830284238 0.1180073544383049 0.0033266174141317606 0.06005686894059181 0.002241848036646843 0.06713001430034637 0.0026064224075526
0.07080326229333878 0.002727324375882745 0.06491228938102722 0.0025215521454811096 0.05090469494462013 0.0018093108665198088 0.05472536012530327 0.0030128019861876965 0.06709842383861542 0.004451180808246136
0.0735495537519455 0.0019546186085790396 0.05845671519637108 0.0013967320555821061 0.04122907295823097 0.0018729055300354958 0.06740739941596985 0.003240782069042325 0.07806723564863205 0.002915367716923356
0.03435903787612915 0.0025256327353417873 0.030306274071335793 0.001933713210746646 0.09341749548912048 0.002114050555974245 training converged
# # save potentials
# if np.array(loss_log)[-10:, 1].mean() <= 0.005:
# np.savetxt(model_path + '/potential.txt', potential, delimiter=',')
# Post-training evaluation: run longer simulations with the learned potential
# (no optimization) and compare the averaged observables to the targets.
rdf_dev = []
for j, sim in enumerate(sim_list):
    #simulate with no optimization
    data_str = (data_str_list + val_str_list)[j]
    dt = 0.005
    all_vacf_sim = []
    for i in range(n_sim):
        v_t, q_t, pv_t = sim.simulate(steps=tau, frequency=tau, dt=dt)
        # compute VACF
        vacf_sim = vacf_obs_list[j](v_t).detach().cpu().numpy()
        all_vacf_sim.append(vacf_sim)
    all_vacf_sim = np.array(all_vacf_sim).mean(0)

    # All logged positions (training + evaluation segments).
    trajs = torch.Tensor( np.stack( sim.log['positions'])).to(system.device).detach()

    # get targets
    if vacf_target_list[j] is not None:
        vacf_target = vacf_target_list[j][:t_range].detach().cpu().numpy()
    else:
        vacf_target = None
    rdf_target = rdf_target_list[j].detach().cpu().numpy()

    # loop over frames to compute observables (serially, one frame at a time)
    all_g_sim = []
    for i in range(len(trajs)):
        _, _, g_sim = rdf_obs_list[j](trajs[[i]])
        all_g_sim.append(g_sim.detach().cpu().numpy())
    all_g_sim = np.array(all_g_sim).mean(0)

    # mean absolute RDF deviation, recorded for training state points only
    if data_str in data_str_list:
        drdf = np.abs(all_g_sim - rdf_target_list[j].cpu().numpy()).mean()
        rdf_dev.append(drdf)

    # plot observables
    plot_vacf(all_vacf_sim, vacf_target,
              fn=data_str,
              path=model_path,
              dt=dt,
              save_data=True)
    plot_rdf(all_g_sim, rdf_target,
             fn=data_str,
             path=model_path,
             start=rdf_start,
             nbins=nbins,
             save_data=True,
             end=rdf_obs_list[j].r_axis[-1])
# rdf_dev = np.abs(all_g_sim - rdf_target).mean()
Evaluation¶
This evaluation pass also saves the simulated trajectories to XYZ files for visualization.
# Save the learned potential if the final RDF loss is below threshold
if np.array(loss_log)[-10:, 1].mean() <= 0.005:
    potential = plot_pair(path=model_path,
                          fn="final",
                          model=sim.integrator.model.models['pairnn'].model,
                          prior=sim.integrator.model.models['pair'].model,
                          device=device,
                          target_pot=target_pot.to(device),
                          end=cutoff)
    # `potential` is the fitted curve on plot_pair's distance grid.
    np.savetxt(model_path + '/potential.txt', potential, delimiter=',')
rdf_dev = []  # reset: deviations are recollected in the final evaluation below
import ase
import ase.io
# Function to save trajectory as XYZ file
def save_traj(system, traj, fname, skip=10):
    """
    Saves atomic trajectories as an XYZ file for visualization.
    Parameters:
    system: ASE Atoms object representing the molecular system.
    traj: List of atomic positions at different timesteps.
    fname: Filename for saving the trajectory.
    skip: Interval for saving frames (reduces file size).
    """
    numbers = system.get_atomic_numbers()
    # Keep every `skip`-th frame to limit output file size.
    frames = [ase.Atoms(positions=coords, numbers=numbers)
              for step, coords in enumerate(traj) if step % skip == 0]
    ase.io.write(fname, frames)  # Save as XYZ file
# Final evaluation loop: rerun evaluation and additionally dump trajectories
# to XYZ files for visualization.
for j, sim in enumerate(sim_list):
    # Simulate with the trained model (without optimization)
    data_str = (data_str_list + val_str_list)[j]

    # Positions logged so far (training-time trajectory).
    train_traj = sim.log['positions']
    # Save training trajectory
    save_traj(system_list[j], train_traj, model_path + '/{}_train.xyz'.format(data_str), skip=10)

    dt = 0.005
    all_vacf_sim = []
    for i in range(n_sim):
        v_t, q_t, pv_t = sim.simulate(steps=tau, frequency=tau, dt=dt)
        # Compute VACF
        vacf_sim = vacf_obs_list[j](v_t).detach().cpu().numpy()
        all_vacf_sim.append(vacf_sim)
    all_vacf_sim = np.array(all_vacf_sim).mean(0)

    trajs = torch.Tensor(np.stack(sim.log['positions'])).to(system.device).detach()

    # Get target RDF and VACF
    if vacf_target_list[j] is not None:
        vacf_target = vacf_target_list[j][:t_range].detach().cpu().numpy()
    else:
        vacf_target = None
    rdf_target = rdf_target_list[j].detach().cpu().numpy()

    # Compute RDF over all saved trajectories (one frame at a time)
    all_g_sim = []
    for i in range(len(trajs)):
        _, _, g_sim = rdf_obs_list[j](trajs[[i]])
        all_g_sim.append(g_sim.detach().cpu().numpy())
    all_g_sim = np.array(all_g_sim).mean(0)

    # Compute target deviation (training state points only)
    if data_str in data_str_list:
        drdf = np.abs(all_g_sim - rdf_target_list[j].cpu().numpy()).mean()
        rdf_dev.append(drdf)

    # Plot and save observables (VACF and RDF)
    plot_vacf(all_vacf_sim, vacf_target,
              fn=data_str,
              path=model_path,
              dt=dt,
              save_data=True)
    plot_rdf(all_g_sim, rdf_target,
             fn=data_str,
             path=model_path,
             start=rdf_start,
             nbins=nbins,
             save_data=True,
             end=rdf_obs_list[j].r_axis[-1])

    # Save final trajectory (inference trajectory)
    save_traj(system_list[j], np.stack(sim.log['positions']),
              model_path + '/{}_sim.xyz'.format(data_str), skip=1)
# NOTE(review): `potential` here is whichever curve plot_pair last returned
# (training-loop or final-save cell) -- confirm the intended curve is saved.
np.savetxt(model_path + '/potential.txt', potential, delimiter=',')
np.savetxt(model_path + '/rdf_dev.txt', np.array(rdf_dev), delimiter=',')
# save loss curve
history = np.array(loss_log)
plt.plot(history[:, 0], label='vacf', alpha=0.7)
plt.plot(history[:, 1], label='rdf', alpha=0.7)
plt.yscale("log")
plt.legend()
plt.savefig(model_path + '/loss.pdf', bbox_inches='tight')
plt.show()
plt.close()

# Mean RDF loss over the last 10 epochs.
print( np.array(loss_log)[-10:, 1].mean() )
0.0025417785276658834